
# Workflow configuration file. Every rule below reads its directories,
# seeds and per-setting parameters from the resulting `config` dictionary.
configfile: "config/revised1.yaml"

# Container image declared for the workflow. The commented-out line is a
# disabled alternative image.
# singularity: "docker://alzhang/scrna-analysis-sim-3.5:v1.14"
singularity: "docker://alzhang/spectrum-master-rstudio:v1.4"

# One timing file per combination of simulation seed, simulation setting,
# evaluation setting and timing seed listed in the config.
_ALL_TIMING_FILES = expand(
    '{0}/cellassign_times/{{seed}}/{{simtype}}/{{evaltype}}/{{timeseed}}.rdata'.format(config['outdir']),
    seed=config['seeds'],
    simtype=config['simulation_settings'],
    evaltype=config['evaluation_settings'],
    timeseed=config['timing_seeds'],
)

# Default target: every per-run timing file plus the collated summary table.
rule all:
    input:
        # Disabled target (simulated datasets only):
        #expand('{workdir}/simulated_sces/{{seed}}/{{simtype}}.rdata'.format(workdir=config['workdir']), simtype=config['simulation_settings'], seed=config['seeds']),
        _ALL_TIMING_FILES,
        '{0}/cellassign_times_summary.tsv'.format(config['outdir']),

def _sim_setting(key):
    """Build a params function for one simulation setting.

    The returned callable receives the job's wildcards and returns
    config['simulation_settings'][<simtype wildcard>][key].
    """
    def lookup(wildcards):
        return config['simulation_settings'][wildcards.simtype][key]
    return lookup


# Run R/simulate_sce.R once per (seed, simtype) pair.
#
# Input:  the base parameter file named by config['base_sce_param_file'].
# Output: <workdir>/simulated_sces/<seed>/<simtype>.rdata
# All simulation arguments come from config['simulation_settings'][<simtype>];
# the seed wildcard is forwarded as --seed.
rule simulate_sce:
    input:
        base_sce_param=config['base_sce_param_file'],
    output:
        '{workdir}/simulated_sces/{{seed}}/{{simtype}}.rdata'.format(workdir=config['workdir']),
    params:
        # Not used by the shell command below; presumably consumed as a job
        # name by the cluster submission setup -- TODO confirm.
        name='simulate-sce-{simtype}-{seed}',
        num_groups=_sim_setting('num_groups'),
        num_batches=_sim_setting('num_batches'),
        num_cells=_sim_setting('num_cells'),
        num_genes=_sim_setting('num_genes'),
        group_probs=_sim_setting('group_probs'),
        batch_probs=_sim_setting('batch_probs'),
        de_facloc=_sim_setting('de_facloc'),
        de_facscale=_sim_setting('de_facscale'),
        de_nu=_sim_setting('de_nu'),
        de_min=_sim_setting('de_min'),
        de_max=_sim_setting('de_max'),
        batch_facloc=_sim_setting('batch_facloc'),
        batch_facscale=_sim_setting('batch_facscale'),
        de_prob=_sim_setting('de_prob'),
        down_prob=_sim_setting('down_prob'),
        sim_model=_sim_setting('sim_model'),
        seed='{seed}',
    log:
        '{logdir}/logs/simulate_sce/{{seed}}/{{simtype}}.log'.format(logdir=config['logdir']),
    benchmark:
        '{logdir}/benchmarks/simulate_sce/{{seed}}/{{simtype}}.txt'.format(logdir=config['logdir']),
    shell:
        'Rscript R/simulate_sce.R '
        '--num_groups {params.num_groups} '
        '--num_batches {params.num_batches} '
        '--num_cells {params.num_cells} '
        '--num_genes {params.num_genes} '
        '--group_probs {params.group_probs} '
        '--batch_probs {params.batch_probs} '
        '--de_facscale {params.de_facscale} '
        '--de_facloc {params.de_facloc} '
        '--de_nu {params.de_nu} '
        '--de_min {params.de_min} '
        '--de_max {params.de_max} '
        '--batch_facscale {params.batch_facscale} '
        '--batch_facloc {params.batch_facloc} '
        '--de_prob {params.de_prob} '
        '--down_prob {params.down_prob} '
        '--seed {params.seed} '
        '--base_sce_param {input.base_sce_param} '
        '--sim_model {params.sim_model} '
        '--outfname {output} '
        # POSIX form of "send stdout and stderr to the log"; the bash-only
        # `>&` spelling fails when the job shell is a plain `sh`.
        '> {log} 2>&1'

def _eval_setting(key):
    """Build a params function for one evaluation setting.

    The returned callable receives the job's wildcards and returns
    config['evaluation_settings'][<evaltype wildcard>][key].
    """
    def lookup(wildcards):
        return config['evaluation_settings'][wildcards.evaltype][key]
    return lookup


# Run R/benchmark_cellassign.R on one simulated dataset, once per
# (seed, simtype, evaltype, timeseed) combination.
#
# Input:  <workdir>/simulated_sces/<seed>/<simtype>.rdata
# Output: a result file under <workdir>/cellassign_results/ (--outfname) and
#         a timing file under <outdir>/cellassign_times/ (--outtime).
rule run_cellassign:
    input:
        sce='{workdir}/simulated_sces/{{seed}}/{{simtype}}.rdata'.format(workdir=config['workdir']),
    output:
        outres='{workdir}/cellassign_results/{{seed}}/{{simtype}}/{{evaltype}}/{{timeseed}}.rdata'.format(workdir=config['workdir']),
        outtime='{outdir}/cellassign_times/{{seed}}/{{simtype}}/{{evaltype}}/{{timeseed}}.rdata'.format(outdir=config['outdir']),
    params:
        # Not used by the shell command below; presumably consumed as a job
        # name by the cluster submission setup -- TODO confirm.
        name='run-cellassign-{simtype}-{evaltype}-{seed}-{timeseed}',
        max_genes=_eval_setting('max_genes'),
        fc_percentile=_eval_setting('fc_percentile'),
        expr_percentile=_eval_setting('expr_percentile'),
        # The timing seed wildcard is what the script receives as --seed.
        timeseed='{timeseed}',
        # Hard-coded here rather than read from the config.
        max_batch_size=5000,
        conda_env='r-tensorflow',
    # NOTE(review): the thread count is declared for the job but is not
    # forwarded to the R script on the command line.
    threads: 20
    log:
        '{logdir}/logs/run_cellassign/{{simtype}}/{{evaltype}}/{{seed}}_{{timeseed}}.log'.format(logdir=config['logdir']),
    benchmark:
        '{logdir}/benchmarks/run_cellassign/{{simtype}}/{{evaltype}}/{{seed}}_{{timeseed}}.txt'.format(logdir=config['logdir']),
    shell:
        'Rscript R/benchmark_cellassign.R '
        '--sce {input.sce} '
        '--fc_percentile {params.fc_percentile} '
        '--expr_percentile {params.expr_percentile} '
        '--max_genes {params.max_genes} '
        '--seed {params.timeseed} '
        '--conda_env {params.conda_env} '
        '--max_batch_size {params.max_batch_size} '
        '--outfname {output.outres} '
        '--outtime {output.outtime} '
        # POSIX form of "send stdout and stderr to the log"; the bash-only
        # `>&` spelling fails when the job shell is a plain `sh`.
        '> {log} 2>&1'

# Run R/collate_results.R over every timing file produced by run_cellassign
# and write a single summary table to <outdir>/cellassign_times_summary.tsv.
rule collate_results:
    input:
        # Same path pattern as run_cellassign's `outtime` output, expanded
        # over all seeds, simulation settings, evaluation settings and
        # timing seeds in the config.
        timefiles=expand('{outdir}/cellassign_times/{{seed}}/{{simtype}}/{{evaltype}}/{{timeseed}}.rdata'.format(outdir=config['outdir']), seed=config['seeds'], simtype=config['simulation_settings'], evaltype=config['evaluation_settings'], timeseed=config['timing_seeds']),
    output:
        '{outdir}/cellassign_times_summary.tsv'.format(outdir=config['outdir']),
    params:
        # Not used by the shell command below; presumably consumed as a job
        # name by the cluster submission setup -- TODO confirm.
        name='collate-results',
        # NOTE(review): passed as a param, not an input, so this file is not
        # tracked as a dependency of the summary -- confirm that is intended.
        benchmark_config=config['benchmark_config_file'],
    log:
        '{logdir}/logs/collate_results.log'.format(logdir=config['logdir']),
    benchmark:
        '{logdir}/benchmarks/collate_results.txt'.format(logdir=config['logdir']),
    shell:
        'Rscript R/collate_results.R '
        '--timing_files {input.timefiles} '
        '--benchmark_config {params.benchmark_config} '
        '--outfname {output} '
        # POSIX form of "send stdout and stderr to the log"; the bash-only
        # `>&` spelling fails when the job shell is a plain `sh`.
        '> {log} 2>&1'